
[None][feat] Update the deepseek routing #13186

Open
ChristinaZ wants to merge 2 commits into NVIDIA:main from ChristinaZ:update_routing_deepseek

Conversation

@ChristinaZ (Collaborator) commented Apr 19, 2026

Summary by CodeRabbit

Release Notes

  • New Features

    • Extended expert model support to accommodate up to 1024 experts
    • Increased top-k expert selection limits for greater flexibility
  • Bug Fixes

    • Relaxed top-k selection constraints to improve configuration compatibility
    • Enhanced configuration validation with detailed error messages for expert parallelism setup
  • Tests

    • Expanded test coverage for expert routing scenarios

Description

Update the customized Deepseekv3Routing kernel and broaden its application range.

Test Coverage

pytest tests/unittest/_torch/thop/parallel/test_noaux_tc.py -v

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

@ChristinaZ ChristinaZ requested a review from a team as a code owner April 19, 2026 08:56
@ChristinaZ ChristinaZ requested a review from mikeiovine April 19, 2026 08:56
@ChristinaZ ChristinaZ requested a review from Wanli-Jiang April 19, 2026 08:56
@yweng0828 (Collaborator)

Summary: Supported-Configuration Changes Before vs. After PR #13186

This PR broadens the range of MoE routing configurations that can use the fused deepseek_v3_topk_kernel instead of falling back to the PyTorch reference, and removes all model-specific hardcoding from both the host-side dispatch and the kernel body.

1. n_group == 1 branch (single-group routing)

| Dimension | Before | After | Net change |
| --- | --- | --- | --- |
| num_experts | ≤ 512 (Python gate), but combined with the `(topk_group == 1 && top_k != 22)` fallback clause, only `(num_experts, top_k) = (512, 22)` actually reached the fused path | ≤ 1024 | Virtually the entire `num_experts ∈ [1, 1024]` space is newly opened |
| top_k | Effectively only `== 22` (the Nemotron Super v3 special case); all other values, including `top_k ≤ 8`, fell back to PyTorch because `topk_group == 1` combined with `top_k != 22` forced the fallback | ≤ 32 | Newly opened: `{1..21, 23..32}`, i.e. `{1..32} \ {22}` |
| topk_group | Non-trivial combined constraint that excluded all `topk_group == 1` configs except `top_k == 22` | No effective constraint (`n_group == 1` implies `topk_group == 1` anyway) | — |

In practice, the n_group == 1 fused path went from a single supported point (Nemotron Super v3, (512, 1, 1, 22)) to essentially the full practical space (any num_experts ≤ 1024, any top_k ≤ 32). Models like Kimi K2 ((384, 1, 1, 8)) and any ≤ 1024-expert single-group model now benefit from the fused kernel.
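
As a concrete illustration of this gating change, here is a minimal Python sketch of the before/after predicates, assuming the simplified form summarized in the table above; the function names are hypothetical, not the actual `routing.py` identifiers.

```python
# Illustrative sketch of the n_group == 1 gating change; names are hypothetical.

def fused_path_before(num_experts: int, topk_group: int, top_k: int) -> bool:
    # Old gate: the (topk_group == 1 and top_k != 22) clause forced a fallback,
    # and n_group == 1 implies topk_group == 1, so only top_k == 22 survived.
    return num_experts <= 512 and not (topk_group == 1 and top_k != 22)

def fused_path_after(num_experts: int, top_k: int) -> bool:
    # New gate: any single-group config within the hard limits is eligible.
    return num_experts <= 1024 and top_k <= 32

assert fused_path_before(512, 1, 22)     # Nemotron Super v3: the one old point
assert not fused_path_before(384, 1, 8)  # Kimi K2 fell back before...
assert fused_path_after(384, 8)          # ...and takes the fused path now
```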

2. n_group > 1 branch (grouped routing)

| Dimension | Before | After | Net change |
| --- | --- | --- | --- |
| num_experts | ≤ 256 | ≤ 256 | Unchanged |
| experts_per_group | ≤ 32 (= WARP_SIZE) | ≤ 32 | Unchanged |
| top_k | ≤ 8 (implicit via the MaxNumTopExperts default) | ≤ 8 (now explicit in the host dispatch) | Unchanged |
| experts_per_group × topk_group | ≤ 128 | ≤ 256 | Doubled; the new zone is (128, 256] |

Within the num_experts ≤ 256, experts_per_group ≤ 32 envelope, the only axis that was expanded is experts_per_group × topk_group, and the expansion happens by doubling the ceiling from 128 to 256. This is implemented by templating MaxNumTopGroups (previously hardcoded to 4) and instantiating a second kernel variant with MaxNumTopGroups = 8 for the new upper band.

DeepSeek-V3's production config (256, 8, 4, 8) is at the old ceiling (32 × 4 = 128) and still takes the default (MaxNumTopGroups = 4) path. New configurations such as (224, 8, 7, 8) (product 196) or (256, 8, 8, 8) (product 256) now dispatch to the MaxNumTopGroups = 8 instance.
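
The dispatch doubling boils down to a small host-side selection rule. The following Python sketch is illustrative, built only from the constraints listed in the table above; `pick_max_num_top_groups` is a hypothetical name, and the real selection lives in the C++ host dispatch.

```python
WARP_SIZE = 32

def pick_max_num_top_groups(num_experts, n_group, topk_group, top_k):
    """Illustrative selection of the MaxNumTopGroups template instantiation."""
    experts_per_group = num_experts // n_group
    if num_experts > 256 or experts_per_group > WARP_SIZE or top_k > 8:
        return None                       # outside the fused envelope: fallback
    if experts_per_group * topk_group <= 128:
        return 4                          # default MaxNumTopGroups = 4 path
    if experts_per_group * topk_group <= 256:
        return 8                          # new MaxNumTopGroups = 8 path
    return None

assert pick_max_num_top_groups(256, 8, 4, 8) == 4     # DeepSeek-V3: 32 * 4 = 128
assert pick_max_num_top_groups(224, 8, 7, 8) == 8     # 28 * 7 = 196
assert pick_max_num_top_groups(256, 8, 8, 8) == 8     # 32 * 8 = 256
assert pick_max_num_top_groups(512, 8, 6, 8) is None  # 64 experts/group > 32
```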

@yweng0828 yweng0828 force-pushed the update_routing_deepseek branch from b100f3b to 24dec37 Compare April 24, 2026 17:52
@yweng0828 (Collaborator)

/bot run --disable-fail-fast

@coderabbitai Bot (Contributor) commented Apr 24, 2026

📝 Walkthrough

Refactors top-k sorting logic by implementing a complete sorting network, updates MoE kernel constants and dispatch predicates, adds validation guards for intranode expert constraints, updates routing compatibility gating, and extends test parameterization for deepseek configurations.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| **Top-K Sorting Implementation**<br>`cpp/tensorrt_llm/kernels/moeTopKFuncs.cuh` | Replaces the partial `Sort` declaration with a complete sorting-network implementation, relaxes the top-k constraint from `K < kWARP_SIZE` to `K <= kWARP_SIZE`, tightens the candidate bound to `N <= 32`, and consolidates `reduceTopK` into a single function that always sorts before reduction. |
| **MoE Kernel Infrastructure**<br>`cpp/tensorrt_llm/kernels/noAuxTcKernels.cu` | Replaces model-specific expert constants with `MaxSupportedExpertCount = 1024` and `MaxSupportedTopExperts = 32`, adds a `MaxNumTopGroups` template parameter to the kernel, refactors group/non-group control paths with localized memory allocation, and updates dispatch predicates based on expert-count thresholds. |
| **Intranode Validation**<br>`tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep.py` | Adds explicit early validation in `DeepEP.__init__` enforcing `num_slots % moe_ep_size == 0` and a per-rank expert limit of 128, raising `RuntimeError` with parameter details on constraint violation (sketched after this table). |
| **Routing Compatibility Gating**<br>`tensorrt_llm/_torch/modules/fused_moe/routing.py` | Updates the `Deepseekv3RoutingImpl.noaux_tc` fusion gating with revised thresholds: the multi-group path uses `num_experts > 256` and `experts_per_group * topk_group > 256`; the single-group path removes the special `top_k == 22` handling and adopts a `num_experts > 1024 or top_k > 32` constraint. |
| **Test Coverage Extension**<br>`tests/unittest/_torch/thop/parallel/test_noaux_tc.py` | Extends `test_noaux_tc_run` parameterization with additional routing configuration tuples to increase coverage of expert/group/topk combinations. |
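
For reference, here is a minimal sketch of the `DeepEP.__init__` guard described in the Intranode Validation row, assuming only the two constraints named there; the real constructor takes additional arguments not shown here.

```python
# Minimal sketch of the DeepEP intranode validation; only the two documented
# constraints are modeled, and everything else about the real class is omitted.
class DeepEP:
    def __init__(self, num_slots: int, moe_ep_size: int):
        if num_slots % moe_ep_size != 0:
            raise RuntimeError(
                f"num_slots ({num_slots}) must be divisible by "
                f"moe_ep_size ({moe_ep_size})")
        if num_slots // moe_ep_size > 128:
            raise RuntimeError(
                f"At most 128 experts per rank are supported, got "
                f"{num_slots // moe_ep_size} "
                f"(num_slots={num_slots}, moe_ep_size={moe_ep_size})")
```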

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 0.00%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title '[None][feat] Update the deepseek routing' is directly related to the main changes, which update DeepSeek routing kernel application ranges and constraints. |
| Description check | ✅ Passed | The PR description explains the core change ('Update the customized kernel and its application range of Deepseekv3Routing'), provides test coverage details, and includes a completed checklist. However, it lacks specific technical details about constraint changes and impact. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |


Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai Bot (Contributor) left a comment

🧹 Nitpick comments (1)
tests/unittest/_torch/thop/parallel/test_noaux_tc.py (1)

8-19: Good coverage expansion, but consider adding a case for the new MaxNumTopGroups=8 kernel path.

The new parameterization covers:

  • Single-group: (72,1,1,6), (384,1,1,8), (512,1,1,22), (1024,1,1,32), (512,1,1,32) — all within the updated limits
  • Multi-group: (256,8,4,8), (256,8,2,8) — uses DefaultMaxNumTopGroups=4 path
  • Fallback: (512,8,6,8) — correctly falls back since experts_per_group=64 > 32

However, there's no test case that exercises the new LargeMaxNumTopGroups=8 kernel path introduced in noAuxTcKernels.cu. This path is triggered when experts_per_group * topk_group > 128 but ≤ 256.

Consider adding a case like (256, 8, 6, 8) which has experts_per_group=32 and topk_group=6, giving a product of 192 that would dispatch to the MaxNumTopGroups=8 variant.

💡 Suggested addition
```diff
 @pytest.mark.parametrize(
     "num_experts, n_group, topk_group, top_k",
     [
         (256, 8, 4, 8),
         (72, 1, 1, 6),
         (384, 1, 1, 8),
         (512, 1, 1, 22),
         (1024, 1, 1, 32),
         (512, 1, 1, 32),
         (256, 8, 2, 8),
+        (256, 8, 6, 8),  # LargeMaxNumTopGroups path (32*6=192 > 128)
         (512, 8, 6, 8),  # fallback
     ])
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unittest/_torch/thop/parallel/test_noaux_tc.py` around lines 8 - 19,
Add a new parameterized test case to exercise the MaxNumTopGroups=8 kernel path
by inserting the tuple (256, 8, 6, 8) into the pytest.mark.parametrize list in
tests/unittest/_torch/thop/parallel/test_noaux_tc.py; this tuple yields
experts_per_group=32 and topk_group=6 (product 192) which triggers the
LargeMaxNumTopGroups=8 path in noAuxTcKernels.cu, so ensure it is placed among
the existing cases in the parameter list used by the test function.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 6aa31ba0-bc47-4a08-8b35-dab5e3be5618

📥 Commits

Reviewing files that changed from the base of the PR and between 5fd0a6c and 24dec37.

📒 Files selected for processing (5)
  • cpp/tensorrt_llm/kernels/moeTopKFuncs.cuh
  • cpp/tensorrt_llm/kernels/noAuxTcKernels.cu
  • tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep.py
  • tensorrt_llm/_torch/modules/fused_moe/routing.py
  • tests/unittest/_torch/thop/parallel/test_noaux_tc.py

@tensorrt-cicd (Collaborator)

PR_Github #45430 [ run ] triggered by Bot. Commit: 24dec37 Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #45430 [ run ] completed with state SUCCESS. Commit: 24dec37
/LLM/main/L0_MergeRequest_PR pipeline #35663 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@yweng0828 (Collaborator)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #45587 [ run ] triggered by Bot. Commit: 24dec37 Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #45587 [ run ] completed with state SUCCESS. Commit: 24dec37
/LLM/main/L0_MergeRequest_PR pipeline #35803 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

ChristinaZ and others added 2 commits April 26, 2026 20:51
Signed-off-by: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com>
Signed-off-by: Yue Weng <25103990+yweng0828@users.noreply.github.com>
@yweng0828 yweng0828 force-pushed the update_routing_deepseek branch from 24dec37 to db3381f Compare April 27, 2026 03:52
@yweng0828 (Collaborator)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #45642 [ run ] triggered by Bot. Commit: db3381f Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #45642 [ run ] completed with state SUCCESS. Commit: db3381f
/LLM/main/L0_MergeRequest_PR pipeline #35855 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@yweng0828 (Collaborator)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #45739 [ run ] triggered by Bot. Commit: db3381f Link to invocation
